跳到主要内容

chroma 向量数据库

快速使用

import chromadb

# Create a client with default settings and a fresh collection in it.
client = chromadb.Client()
collection = client.create_collection(name="sample_collection")

# Sample rows: Chroma embeds the documents itself unless you pass your own
# embeddings. Collections also support update() and delete().
sample_docs = ["This is document1", "This is document2"]
sample_metadatas = [{"source": "notion"}, {"source": "google-docs"}]  # arbitrary metadata, usable in filters
sample_ids = ["doc1", "doc2"]  # every id must be unique within the collection

collection.add(documents=sample_docs, metadatas=sample_metadatas, ids=sample_ids)

# Ask for the 2 documents nearest to the query text. Optional filters:
#   where={"metadata_field": "is_equal_to_this"}   -> filter on metadata
#   where_document={"$contains": "search_string"}  -> filter on document text
results = collection.query(query_texts=["This is a query document"], n_results=2)

all-MiniLM-L6-v2 模型

可以注意到，首次使用时 Chroma 会自动下载 all-MiniLM-L6-v2 模型。它是默认的嵌入（embedding）模型，作用是把输入的文本转换成向量。

下面封装一个通用的存储类，用于持久化保存任务结果，并按语义相似度检索相关结果；文本向量化同样使用 all-MiniLM-L6-v2 模型（ONNX 版本）。

import logging
from typing import Dict, List

import chromadb
from chromadb.utils.embedding_functions import ONNXMiniLM_L6_V2

class DefaultResultsStorage:
    """Store task results in a persistent Chroma collection and query them.

    Result texts are embedded with the ONNX all-MiniLM-L6-v2 embedding
    function, and the collection's HNSW index uses cosine distance.
    """

    def __init__(self, results_store_name: str, persist_directory: str = "chroma"):
        """Open (or create) the collection ``results_store_name``.

        Args:
            results_store_name: Name of the Chroma collection to use.
            persist_directory: Directory in which Chroma persists its data.
                Defaults to ``"chroma"``, the previously hard-coded location.
        """
        # Suppress chromadb log records below ERROR level.
        logging.getLogger('chromadb').setLevel(logging.ERROR)

        # PersistentClient reads its storage location from ``path``; passing
        # it there (rather than via Settings.persist_directory) is the
        # documented way to choose the directory.
        chroma_client = chromadb.PersistentClient(path=persist_directory)

        self.collection = chroma_client.get_or_create_collection(
            name=results_store_name,
            # Cosine distance for nearest-neighbour search.
            metadata={"hnsw:space": "cosine"},
            embedding_function=ONNXMiniLM_L6_V2(),
        )

    def add(self, task: Dict, result: str, result_id: str):
        """Store ``result`` under ``result_id``, replacing any existing entry.

        Args:
            task: Task description; only its ``"task_name"`` value is stored.
            result: Result text. It is embedded as the document and also kept
                in the metadata under the ``"result"`` key.
            result_id: Id of the result within the collection.

        Raises:
            KeyError: If ``task`` has no ``"task_name"`` key.
        """
        metadata = {"task": task["task_name"], "result": result}
        # upsert inserts a new id or updates an existing one in one call,
        # replacing the previous get-then-add/update round trip. No explicit
        # embeddings are given, so the collection's embedding function is used.
        self.collection.upsert(
            ids=[result_id],
            documents=[result],
            metadatas=[metadata],
        )

    def query(self, query: str, top_results_num: int) -> List[dict]:
        """Return the task names of the stored results most similar to ``query``.

        Args:
            query: Text to search for.
            top_results_num: Maximum number of matches to return.

        Returns:
            Up to ``top_results_num`` task names in the order Chroma returns
            them. The list is empty when the collection is empty or when
            ``top_results_num`` is not positive.
        """
        # Chroma rejects a non-positive n_results, so short-circuit first.
        if top_results_num <= 0:
            return []
        count = self.collection.count()
        if count == 0:
            return []
        results = self.collection.query(
            query_texts=[query],
            # Never request more neighbours than the collection holds.
            n_results=min(top_results_num, count),
            include=["metadatas"],
        )
        # A single query text was sent, so only the first result row applies.
        return [item["task"] for item in results["metadatas"][0]]

Reference